import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from scipy.special import lambertw
from contextlib import nullcontext
from typing import List, Dict, Any, Callable,Tuple

def calc_Fd_and_histNormCounts(
    full_prec_array: np.ndarray,
    edges: np.ndarray,
    x_min: float,
    log_ratio_edges: np.ndarray
) -> tuple:
    """
    Split precipitation samples into exact zeros, trace values and binned values.

    Every sample of the (time, row, column) array is classified as one of:
      - exactly zero (counted in ``Fd``);
      - strictly positive but strictly below ``x_min`` (counted in ``lessRmin`` and
        excluded from the histogram);
      - anything else, which is histogrammed per grid cell with ``np.histogram`` over ``edges``.

    Parameters:
    full_prec_array (numpy.ndarray): 3D array shaped (time steps, rows, columns).
    edges (numpy.ndarray): 1D array of histogram bin edges.
    x_min (float): Trace threshold; values with 0 < value < x_min are left out of the histogram.
        A value equal to x_min is NOT a trace value.
    log_ratio_edges (numpy.ndarray): Per-bin widths used to normalize the counts.

    Returns:
    tuple: (Fd, lessRmin, hist_counts)
        - Fd: fraction of all samples (time steps x grid cells) that are exactly zero.
        - lessRmin: fraction of all samples that are > 0 and < x_min.
        - hist_counts (numpy.ndarray): float array shaped (rows, columns, len(edges) - 1);
          each cell's counts divided by (time steps * rows * columns * log_ratio_edges).
    """
    n_steps, n_rows, n_cols = full_prec_array.shape
    # Normalizer shared by every fraction: all samples over time and the whole grid.
    n_samples = n_steps * (n_rows * n_cols)

    zero_mask = full_prec_array == 0
    trace_mask = (full_prec_array > 0) & (full_prec_array < x_min)

    # Trace values become NaN; np.histogram with explicit edges does not count NaN in any bin.
    binnable = np.where(trace_mask, np.nan, full_prec_array)

    Fd = np.sum(zero_mask) / n_samples
    lessRmin = np.sum(trace_mask) / n_samples

    hist_counts = np.zeros((n_rows, n_cols, len(edges) - 1))
    for r, c in np.ndindex(n_rows, n_cols):
        cell_counts, _ = np.histogram(binnable[:, r, c], edges)
        hist_counts[r, c, :] = cell_counts / (n_samples * log_ratio_edges)

    return Fd, lessRmin, hist_counts



def histNormCounts_with_Lambert(full_prec_array, edges,Lambert_dx):
    """
    Per-grid-cell histogram of precipitation, normalized by Lambert W bin widths.

    Parameters:
    ----------
    full_prec_array : numpy.ndarray
        3D array shaped (time steps, rows, columns).
    edges : numpy.ndarray
        1D array of histogram bin edges passed to ``np.histogram``.
    Lambert_dx : numpy.ndarray
        Per-bin widths in Lambert W space; one element fewer than ``edges``.

    Returns:
    -------
    numpy.ndarray
        Float array shaped (rows, columns, len(edges) - 1). Each cell's counts are divided
        by (time steps * rows * columns * Lambert_dx), giving a frequency distribution.
    """
    n_steps, n_rows, n_cols = full_prec_array.shape
    n_cells = n_rows * n_cols

    pdf = np.zeros((n_rows, n_cols, len(edges) - 1))

    for r, c in np.ndindex(n_rows, n_cols):
        cell_counts = np.histogram(full_prec_array[:, r, c], bins=edges)[0]
        pdf[r, c, :] = cell_counts / (n_steps * n_cells * Lambert_dx)

    return pdf


def get_x_axis_arrays(open_path=None, fname='bin_edges_centers_seven_percent.npz'):
    """
    Load the x-axis bin edges and bin centers from a ``.npz`` archive.

    Parameters:
    open_path (str, optional): Directory containing the archive. When None, the original
        hard-coded location /Volumes/COO/MFP_NJ/RESULTS/FILES/HIST/W_Fd is used.
    fname (str, optional): Archive file name inside ``open_path``.

    Returns:
    tuple: (edges, centers), the arrays stored under the 'edges' and 'centers' keys.
    """
    if open_path is None:
        open_path = os.path.join('/Volumes', 'COO', 'MFP_NJ', 'RESULTS', 'FILES', 'HIST', 'W_Fd')

    # np.load on an .npz returns an NpzFile that keeps the file open and reads members lazily;
    # the context manager closes it once both arrays have been materialized.
    with np.load(os.path.join(open_path, fname)) as archive:
        return archive['edges'], archive['centers']


def LWF_xaxis(C, b, x_cutoff):
    """
    Generate x_n = (C*n + b) * exp(C*n + b) for n = 0, 1, 2, ..., stopping at x_cutoff.

    Each term is w * exp(w) for the linear sequence w_n = C*n + b. For w_n >= -1 this is the
    inverse of the principal Lambert W, so the terms are evenly spaced (by C) in Lambert W
    space. Terms are appended up to and including the first one that is >= x_cutoff.

    Parameters:
    ----------
    C : float
        Step of the linear sequence w_n (spacing in Lambert W space).
    b : float
        Starting value w_0 of the linear sequence.
    x_cutoff : float
        The series stops at the first term >= x_cutoff.

    Returns:
    -------
    np.ndarray
        The generated terms; the last one is the first term that reaches x_cutoff.

    Raises:
    ------
    ValueError
        If any input is NaN, or if the series provably never reaches x_cutoff: C == 0 with
        the first term below the cutoff, or C < 0 with a positive cutoff above the first term.
    """
    if np.isnan(C) or np.isnan(b) or np.isnan(x_cutoff):
        raise ValueError("C, b and x_cutoff must not be NaN.")

    first_value = b * np.exp(b)
    # With C == 0 every term equals first_value. With C < 0, w decreases, so each later term is
    # either no greater than first_value or negative. In both cases a cutoff above first_value
    # (and above 0 when C < 0) can never be reached, and the loop below would never end.
    if first_value < x_cutoff and (C == 0 or (C < 0 and x_cutoff > 0)):
        raise ValueError("The series never reaches x_cutoff for this C; C should be positive.")

    series = []
    n = 0
    while True:
        w = C * n + b
        current_value = w * np.exp(w)
        series.append(current_value)
        # Include the first term at or beyond the cutoff, then stop.
        if current_value >= x_cutoff:
            break
        n += 1

    return np.array(series)



def calculate_log_ratios(arr):
    """
    Return log(arr[i] / arr[i-1]) for every consecutive pair.

    For geometrically growing bin edges these are the (nearly constant) bin widths in log
    space, del log R, used to normalize histogram counts.

    Parameters:
    arr (array-like): 1D sequence of at least two non-zero values (e.g. bin edges).

    Returns:
    numpy.ndarray: The len(arr) - 1 log ratios.

    Raises:
    ValueError: If arr has fewer than two elements, or if any element is zero.
    """
    if len(arr) < 2:
        raise ValueError("Array must contain at least two elements to calculate logarithmic ratios.")

    values = np.asarray(arr)
    # Each element is the numerator and/or the denominator of some ratio. A zero gives either a
    # division by zero or log(0), so reject zeros anywhere, including the first element.
    if np.any(values == 0):
        raise ValueError("Division by zero encountered in the array.")

    return np.log(values[1:] / values[:-1])


def calculate_lambert_dx(arr):
    """
    Compute bin widths in Lambert W space: W(arr[i]) - W(arr[i-1]) for each consecutive pair.

    W is the principal branch of ``scipy.special.lambertw``, and only its real part is used.
    No division is performed, so zero entries (W(0) == 0) are fine, e.g. bin edges that
    start at 0.

    Parameters:
    ----------
    arr : array-like
        A 1D sequence of at least two numeric values (typically histogram bin edges).

    Returns:
    -------
    np.ndarray
        The len(arr) - 1 differences of the real part of W over consecutive elements.

    Raises:
    ------
    ValueError:
        If the array has fewer than two elements.

    Example:
    -------
    calculate_lambert_dx([0, np.e, 2 * np.e**2]) is approximately array([1., 1.]),
    because W takes the values 0, 1 and 2 at those points.
    """
    if len(arr) < 2:
        raise ValueError("Array must contain at least two elements to calculate Lambert W ratios.")

    # Transform every edge once, then difference neighbours.
    w_values = lambertw(np.asarray(arr)).real
    return np.diff(w_values)

def calculate_midpoints(arr, transform_to_LW=False):
    """
    Calculate midpoints of consecutive bin edges in either the original or Lambert W space.

    Each midpoint is the geometric mean sqrt(e[i] * e[i+1]) of neighbouring edges e. When
    ``transform_to_LW`` is True, e is the real part of the principal Lambert W of ``arr``;
    otherwise e is ``arr`` itself.

    Parameters:
    ----------
    arr : array-like
        The input bin edges (lists are accepted as well as arrays).
    transform_to_LW : bool
        If True, compute midpoints of the Lambert W-transformed edges.

    Returns:
    -------
    np.ndarray
        Array of len(arr) - 1 midpoints.

    Raises:
    ------
    ValueError
        If ``arr`` has fewer than two elements.
    """
    if len(arr) < 2:
        raise ValueError("Array must contain at least two elements to calculate midpoints.")

    # asarray lets plain lists work: list * list is not element-wise.
    edges = np.asarray(arr)
    if transform_to_LW:
        edges = lambertw(edges).real

    return np.sqrt(edges[:-1] * edges[1:])
  


from contextlib import nullcontext
from typing import Callable, Tuple

def process_single_surface_temperature_case(
    spath: str,
    diag_array: np.ndarray,
    model_name: str,
    surf_temp: str,
    bin_edges: np.ndarray,
    x_min: float,
    log_xspaces: np.ndarray,
    centros: np.ndarray,
    Func2Process: Callable,
    outFileName: str,
    outFile: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Process one surface-temperature case and return its frequency and amount distributions.

    Calls ``Func2Process(diag_array, bin_edges, x_min, log_xspaces)``, which must return
    ``(Fd, lessRmin, norm_hist)``. The normalized histogram is summed over its first two
    axes to give the frequency distribution. When ``outFile`` is True, a header and
    diagnostics are appended to ``<spath>/<outFileName>.txt``: Fd, lessRmin, the histogram
    integral sum(frequency * log_xspaces), their total, and whether the total is within
    0.001 of 1.

    Parameters:
    - spath (str): Directory of the log file.
    - diag_array (np.ndarray): Precipitation array handed to Func2Process.
    - model_name (str): Model label written in the log header.
    - surf_temp (str): Surface-temperature label written in the log header.
    - bin_edges (np.ndarray): Bin edges handed to Func2Process.
    - x_min (float): Trace threshold handed to Func2Process.
    - log_xspaces (np.ndarray): Bin widths, handed to Func2Process and used to integrate the histogram.
    - centros (np.ndarray): Bin centers; amount = frequency * centros.
    - Func2Process (Callable): Returns (Fd, lessRmin, norm_hist).
    - outFileName (str): Log file name without the '.txt' extension.
    - outFile (bool, optional): Append diagnostics to the log file when True. Defaults to False.

    Returns:
    - Tuple[np.ndarray, np.ndarray]: (frequency, amount) for this case.
    """
    log_path = os.path.join(spath, f"{outFileName}.txt")
    log_context = open(log_path, 'a') if outFile else nullcontext()

    with log_context as log_handle:

        def emit(text):
            # nullcontext yields None, so only write when a real file is open.
            if outFile:
                log_handle.write(text + '\n')

        emit(f"[{model_name} | TEMP: {surf_temp}]")

        Fd, lessRmin, norm_hist = Func2Process(diag_array, bin_edges, x_min, log_xspaces)

        # Collapse the two grid axes, leaving one value per bin.
        frequency = np.sum(norm_hist, axis=(0, 1))
        hist_sum = np.sum(frequency * log_xspaces)
        tot_area = Fd + lessRmin + hist_sum

        emit(f"Zero events (Fd): {np.round(Fd, 6)}")
        emit(f"Near-zero but non-zero (lessRmin): {np.round(lessRmin, 6)}")
        emit(f"Histogram contribution: {np.round(hist_sum, 6)}")
        emit(f"Total area: {np.round(tot_area, 6)}")

        area_ok = abs(tot_area - 1) <= 0.001
        emit("Total area validation passed." if area_ok
             else f"Warning: Total area does not add to 1 (value: {np.round(tot_area, 6)}).")

        amount = frequency * centros

    return frequency, amount




def process_surface_temperature_cases(
    spath: str, 
    diag_arrays: list, 
    surf_temp_cases: list, 
    bin_edges: np.ndarray, 
    x_min: float,
    log_xspaces: np.ndarray, 
    centros: np.ndarray,
    Func2Process: Callable,
    outFileName: str, 
    model_name: str
) -> tuple:
    """
    Process several surface-temperature cases and append diagnostics for each to a log file.

    For each (diag_array, surf_temp) pair (paired with zip, so the shorter list wins),
    ``Func2Process(diag_array, bin_edges, x_min, log_xspaces)`` must return
    ``(Fd, lessRmin, norm_hist)``. The histogram is summed over its first two axes with
    np.nansum, so NaNs are ignored, giving the frequency distribution. The log at
    ``<spath>/<outFileName>.txt`` (always opened in append mode) records Fd, lessRmin, the
    histogram integral sum(frequency * log_xspaces), their total, and whether the total is
    within 0.001 of 1.

    Parameters:
    - spath: Output directory path.
    - diag_arrays: Precipitation arrays, one per case.
    - surf_temp_cases: Surface-temperature labels, one per case.
    - bin_edges: Bin edges handed to Func2Process.
    - x_min: Trace threshold handed to Func2Process.
    - log_xspaces: Bin widths, handed to Func2Process and used to integrate the histogram.
    - centros: Bin centers; amount = frequency * centros.
    - Func2Process: Returns (Fd, lessRmin, norm_hist).
    - outFileName: Log file name without the '.txt' extension.
    - model_name: Model label written in each case header.

    Returns:
    - Tuple of two lists: frequency arrays and amount arrays, one entry per processed case.
    """
    freq_arrays = []
    amount_arrays = []
    log_path = os.path.join(spath, f"{outFileName}.txt")

    with open(log_path, 'a') as log_handle:
        for diag_array, surf_temp in zip(diag_arrays, surf_temp_cases):
            log_handle.write(f"[{model_name} | TEMP: {surf_temp}]\n")

            Fd, lessRmin, norm_hist = Func2Process(diag_array, bin_edges, x_min, log_xspaces)

            frequency = np.nansum(norm_hist, axis=(0, 1))
            freq_arrays.append(frequency)

            hist_sum = np.sum(frequency * log_xspaces)
            tot_area = Fd + lessRmin + hist_sum

            log_handle.writelines([
                f"Zero events (Fd): {np.round(Fd, 6)}\n",
                f"Near-zero but non-zero (lessRmin): {np.round(lessRmin, 6)}\n",
                f"Histogram contribution: {np.round(hist_sum, 6)}\n",
                f"Total area: {np.round(tot_area, 6)}\n",
            ])

            if abs(tot_area - 1) > 0.001:
                log_handle.write(f"Warning: Total area does not add to 1 (value: {np.round(tot_area, 6)}).\n\n")
            else:
                log_handle.write("Total area validation passed.\n\n")

            amount_arrays.append(frequency * centros)

    return freq_arrays, amount_arrays

def process_surface_temperature_cases_with_Lambert(
    spath: str,
    diag_array: np.ndarray,
    surf_temp: str,
    bin_edges: np.ndarray,
    Lambert_dx: np.ndarray,
    centros: np.ndarray,
    Func2Process: Callable,
    outFile: bool = False,
    outFileName: str = None,
) -> tuple:
    """
    Process one surface-temperature case using Lambert W bin widths.

    ``Func2Process(diag_array, bin_edges, Lambert_dx)`` must return a normalized histogram
    whose first two axes are the grid (e.g. ``histNormCounts_with_Lambert``). It is summed
    over those axes to give the frequency distribution. The total area
    sum(frequency * Lambert_dx) is checked against 1 (tolerance 0.001): 'Pass!' is printed
    on success, and a warning is printed otherwise.

    Parameters:
    ----------
    spath : str
        Output directory path (only used when a log file is written).
    diag_array : numpy.ndarray
        Precipitation array handed to Func2Process.
    surf_temp : str
        Surface-temperature case label written in the log.
    bin_edges : numpy.ndarray
        1D array of bin edges handed to Func2Process.
    Lambert_dx : numpy.ndarray
        1D array of Lambert W bin widths, used for normalization and the area check.
    centros : numpy.ndarray
        1D array of bin centers in precipitation space; amount = frequency * centros.
    Func2Process : Callable
        Histogram function called as Func2Process(diag_array, bin_edges, Lambert_dx).
    outFile : bool
        If True and outFileName is given, also write the diagnostics to a file.
    outFileName : str
        Log file name without the '.txt' extension.

    Returns:
    -------
    tuple:
        - freq_array : numpy.ndarray, 1D frequency distribution for the case.
        - amount_array : numpy.ndarray, 1D amount distribution (frequency * centros).
    """
    write_log = bool(outFile and outFileName)
    # NOTE(review): mode 'w' overwrites the log on each call, whereas the other
    # process_* helpers append; confirm this is intended.
    log_context = (open(os.path.join(spath, f"{outFileName}.txt"), 'w')
                   if write_log else nullcontext())

    with log_context as output_file:
        if write_log:
            output_file.write(f'Processing surface temp case: {surf_temp}.\n')

        norm_PDF = Func2Process(diag_array, bin_edges, Lambert_dx)
        hist_to_mult = np.sum(norm_PDF, axis=(0, 1))  # sum over rows and columns

        tot_area = np.sum(hist_to_mult * Lambert_dx)
        if write_log:
            output_file.write(f'Total Area: {tot_area}\n')

        # The amount distribution is the same whether or not the area check passes.
        amount = hist_to_mult * centros

        if abs(tot_area - 1) <= 0.001:
            print('Pass!')
        else:
            warning_msg = f"Be wary - area is not adding to one--its {np.round(tot_area, 3)}."
            print(warning_msg)
            if write_log:
                output_file.write(f"{warning_msg}\n")

    return np.array([hist_to_mult]).reshape(-1), np.array([amount]).reshape(-1)


def create_dataset(
    var_names: List[str],
    amount_list: List[Any],
    coord_names: List[str],
    coord_values: List[Any],
    attrs_dic: Dict[str, Any]
) -> xr.Dataset:
    """
    Assemble an xarray.Dataset from parallel lists of variable and coordinate definitions.

    Every data variable is declared over all names in ``coord_names`` as its dimensions, in
    that order, so each entry of ``amount_list`` must have one axis per coordinate name.
    Names and values are paired with zip: surplus entries in the longer list of each pair
    are ignored.

    Parameters:
    - var_names: Names of the data variables (e.g. ['a280K', 'a290K', ...]).
    - amount_list: Data for each variable, in the same order as var_names.
    - coord_names: Coordinate names; also used as the dimensions of every variable.
    - coord_values: Values for each coordinate, in the same order as coord_names.
    - attrs_dic: Dataset-level attributes (metadata).

    Returns:
    - The created xarray.Dataset.
    """
    data_vars = {}
    for name, values in zip(var_names, amount_list):
        data_vars[name] = (coord_names, values)

    coords = dict(zip(coord_names, coord_values))

    return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs_dic)


def geo_series(r,x_ini,x_max):
    """
    Return the geometric series x_ini, x_ini*r, x_ini*r**2, ... for all terms <= x_max.

    Parameters:
    r (float): Common ratio; must be > 1.
    x_ini (float): First term. Must be positive whenever x_ini <= x_max.
    x_max (float): Inclusive upper bound on the terms. Must be finite whenever x_ini <= x_max.

    Returns:
    numpy.ndarray: The terms (empty if x_ini > x_max).

    Raises:
    ValueError: If r <= 1, or if the loop could never terminate (x_ini <= 0, or infinite x_max).
    """
    if r <= 1:
        raise ValueError("The growth rate 'r' must be greater than 1 to avoid an infinite loop.")

    if x_ini <= x_max:
        # A non-positive start never grows past x_max; an infinite bound is never exceeded
        # (inf * r stays inf). Both would loop forever.
        if x_ini <= 0:
            raise ValueError("The starting value 'x_ini' must be positive to avoid an infinite loop.")
        if np.isinf(x_max):
            raise ValueError("The upper bound 'x_max' must be finite to avoid an infinite loop.")

    holder = []
    value = x_ini
    while value <= x_max:
        holder.append(value)  # Append current value to the series
        value = value * r     # Multiply by the constant ratio

    return np.array(holder)

def midpoint_giver(array):
    """
    Return the geometric midpoints sqrt(array[i+1] * array[i]) of consecutive elements.

    Parameters:
    array (indexable sequence): Bin edges (list or 1D array).

    Returns:
    numpy.ndarray: len(array) - 1 midpoints (empty when fewer than two elements).
    """
    midpoints = []
    for k in range(1, len(array)):
        midpoints.append(np.sqrt(array[k] * array[k - 1]))
    return np.array(midpoints)
    
def calc_b1(base, compare, grad_base):
    """
    Least-squares shift coefficient b for the model compare - base ~ -b * grad_base.

    b = -sum(grad_base * (compare - base)) / sum(grad_base**2), which minimizes
    sum((compare - base + b * grad_base)**2). Both sums use np.nansum, so NaN terms are
    skipped in each sum independently.

    Parameters:
    base (array-like): Reference distribution.
    compare (array-like): Distribution compared against base.
    grad_base (array-like): Gradient of the base distribution.

    Returns:
    float: The coefficient b.
    """
    difference = compare - base
    numerator = np.nansum(grad_base * difference)
    denominator = np.nansum(np.square(grad_base))
    return -numerator / denominator

def calc_a(base, compare):
    """
    Least-squares scaling coefficient a for the model compare - base ~ a * base.

    a = sum(base * (compare - base)) / sum(base**2), which minimizes
    sum((compare - base - a * base)**2). Both sums use np.nansum, so NaN terms are skipped.

    Parameters:
    base (array-like): Reference distribution.
    compare (array-like): Distribution compared against base.

    Returns:
    float: The coefficient a.
    """
    difference = compare - base
    return np.nansum(base * difference) / np.nansum(np.square(base))


def grad_calculator(array, midpoints, LWF = None):
    """
    Calculate the gradient of ``array`` with respect to a transformed midpoint coordinate.

    The coordinate is x = Re(W(midpoints)) (principal Lambert W) when ``LWF`` is truthy,
    and x = log(midpoints) otherwise. Interior points use the centred difference
    (f[i+1] - f[i-1]) / (x[i+1] - x[i-1]); the two endpoints use one-sided differences.

    Parameters:
    array (list or np.ndarray): Values whose gradient is calculated (at least two elements).
    midpoints (list or np.ndarray): Midpoints matching ``array`` element-for-element.
    LWF (bool, optional): If truthy, differentiate with respect to Lambert W(midpoints);
        otherwise with respect to log(midpoints).

    Returns:
    np.ndarray: The gradient, one value per element of ``array``.
    """
    # asarray lets list inputs work (list - list is not element-wise).
    values = np.asarray(array)
    mids = np.asarray(midpoints)
    coords = lambertw(mids).real if LWF else np.log(mids)

    gradients = np.zeros(len(values))

    # First value: forward difference
    gradients[0] = (values[1] - values[0]) / (coords[1] - coords[0])

    # Middle values: centred differences
    gradients[1:-1] = (values[2:] - values[:-2]) / (coords[2:] - coords[:-2])

    # Last value: backward difference
    gradients[-1] = (values[-1] - values[-2]) / (coords[-1] - coords[-2])

    return gradients

def transform_keys(data_dict, grad_dict, shift=True):
    """
    Compare each case with the next-warmer case and store one coefficient per pair.

    The keys of ``grad_dict`` are sorted by int(key[:-1]), which assumes a one-character
    unit suffix such as 'K'. For every consecutive pair (lower, upper) the result is
    calc_b1(data_dict[lower], data_dict[upper], grad_dict[lower]) when ``shift`` is True,
    and calc_a(data_dict[lower], data_dict[upper]) otherwise. It is stored under
    f"{lower[:-1]}-{upper}" (e.g. '280-290K').

    Parameters:
    - data_dict (dict): Maps case keys (e.g. '280K') to distribution arrays.
    - grad_dict (dict): Maps the same keys to gradient arrays; its keys define the pairs.
    - shift (bool): True for the shift coefficient (calc_b1); False for the scaling
      coefficient (calc_a).

    Returns:
    - dict: Combined pair keys mapped to the computed coefficients.
    Note: relies on calc_b1 and calc_a being defined in the same module.
    """
    ordered = sorted(grad_dict, key=lambda name: int(name[:-1]))

    combined = {}
    for lower, upper in zip(ordered, ordered[1:]):
        if shift:
            value = calc_b1(data_dict[lower], data_dict[upper], grad_dict[lower])
        else:
            value = calc_a(data_dict[lower], data_dict[upper])
        combined[f"{lower[:-1]}-{upper}"] = value

    return combined

def plot_shifted_log_dist(a_arrays,x_vals, xticks, norm_case, key, colors, shift_value, array_shift_index, array_compare_index, save_path):
    """
    Overlay a shifted distribution on a comparison distribution (log x-axis) and save the figure.

    The distribution at ``array_shift_index`` is drawn on edges x_vals * exp(shift_value),
    which on the log axis is a shift of ``shift_value`` in log(x). The distribution at
    ``array_compare_index`` is drawn on the raw ``x_vals``. The figure is saved to
    ``save_path`` and then closed.

    Parameters:
    a_arrays: Sequence of per-bin distributions.
    x_vals: Bin edges shared by both distributions.
    xticks: Tick positions (also used as tick labels and to set the x-limits).
    norm_case: x-axis label text, also used in the y-axis label.
    key: Pair label; its first 3 characters label the shifted case and its last 4 the comparison.
    colors: Sequence of colors indexed like a_arrays.
    shift_value: Shift applied in log(x).
    array_shift_index, array_compare_index: Indices into a_arrays and colors.
    save_path: Output image path.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    curves = (
        (array_shift_index, x_vals * np.exp(shift_value),
         f"{key[:3]}K shifted by {np.round(shift_value, 2)}."),
        (array_compare_index, x_vals, f" {key[-4:]}"),
    )
    for idx, bin_edges, label in curves:
        ax.stairs(a_arrays[idx], bin_edges, edgecolor=colors[idx], linewidth=1, label=label)

    # The log scale must be set before the ticks; setting the scale resets the tick locator.
    ax.set_xscale('log')
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks, rotation=45, fontsize=12)
    ax.set_ylim(0, 0.50)
    tick_top = np.max(xticks)
    # NOTE(review): upper limit is 11x the largest tick (the LWF variant uses 1.1x); confirm.
    ax.set_xlim(np.min(xticks), tick_top + tick_top * 10)
    ax.set_xlabel(norm_case, fontsize=12)
    # NOTE(review): label says "del LWF" although this axis is logarithmic; confirm wording.
    ax.set_ylabel(f"[Prob / del LWF( {norm_case} )]", fontsize=12)
    ax.legend(loc='best', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)

def plot_shifted_LWF_dist(a_arrays,x_vals, xticks, norm_case, key, colors, shift_value, array_shift_index, array_compare_index, save_path):
    """
    Overlay a shifted and a comparison distribution on a Lambert W x-axis and save the figure.

    Bin edges are mapped to W(x_vals), the real part of the principal Lambert W, and
    plotted on a linear axis. Ticks sit at W(xticks) but are labelled with the raw xticks
    values. The x-limits are expressed in the same W units as the data. The figure is
    saved to ``save_path`` and then closed.

    Parameters:
    a_arrays: Sequence of per-bin distributions.
    x_vals: Bin edges in raw units.
    xticks: Raw tick values (labels); their W values give tick positions and x-limits.
    norm_case: x-axis label text, also used in the y-axis label.
    key: Pair label; its first 3 characters label the shifted case and its last 4 the comparison.
    colors: Sequence of colors indexed like a_arrays.
    shift_value: Shift coefficient for the shifted case.
    array_shift_index, array_compare_index: Indices into a_arrays and colors.
    save_path: Output image path.
    """
    labels_transformed = lambertw(xticks).real
    bin_pos = lambertw(x_vals).real

    fig, ax = plt.subplots(figsize=(8, 6))
    # NOTE(review): the shifted edges are W(x_vals) * exp(shift_value), mirroring the log-axis
    # variant (x_vals * exp(shift)). On this linear W axis that rescales rather than translates;
    # confirm whether bin_pos + shift_value was intended.
    ax.stairs(a_arrays[array_shift_index], bin_pos * np.exp(shift_value), edgecolor=colors[array_shift_index], linewidth=1,
              label=f"{key[:3]}K shifted by {np.round(shift_value, 2)}.")
    ax.stairs(a_arrays[array_compare_index], bin_pos, edgecolor=colors[array_compare_index], linewidth=1, label=f" {key[-4:]}")
    ax.set_xticks(labels_transformed)
    ax.set_xticklabels(xticks, rotation=45, fontsize=12)
    ax.set_ylim(0, 0.50)
    # Data and ticks are in W units, so the limits must be too. Raw-unit limits would stretch
    # the axis far beyond the plotted bins (e.g. W(100) is about 3.4).
    w_low = np.min(labels_transformed)
    w_high = np.max(labels_transformed)
    ax.set_xlim(w_low, w_high + w_high * 0.10)
    ax.set_xlabel(norm_case, fontsize=12)
    ax.set_ylabel(f"[Prob / del LWF( {norm_case} )]", fontsize=12)
    ax.legend(loc='best', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)

    
def plot_distribution_panels(
    selected_labels,
    x_vals,
    dist_dic,
    norm_case,
    save_name,
    save_dir,
    save_flag,
    LWF=False
):
    """
    Draw a 1x4 row of panels, one per key of ``dist_dic``, with frequency and amount curves.

    Each panel draws dist_dic[key][0] (coloured, "Frequency") and dist_dic[key][1] (black,
    "Amount") as stairs over the bin edges. Only the first four keys are plotted, because
    they are zipped with the four axes. With ``LWF`` True, edges and tick positions are
    mapped through the real part of the Lambert W function, and ticks are labelled with
    the raw ``selected_labels``. Otherwise the raw edges are drawn on a log x-axis. The
    figure is saved to save_dir/save_name when ``save_flag`` is truthy, then shown.

    Parameters:
    selected_labels: Raw tick values; their maximum sets the right x-limit.
    x_vals: Bin edges in raw units.
    dist_dic: Mapping from case label to a (frequency, amount) pair of arrays.
    norm_case: x-axis label text, also used in the y-axis label.
    save_name, save_dir: Output file name and directory.
    save_flag: Save the figure when truthy.
    LWF: Use the Lambert W axis when True, otherwise a log axis.
    """
    if LWF:
        edges = lambertw(x_vals).real
        tick_positions = lambertw(selected_labels).real
        y_units = f'del LWF({norm_case})'
    else:
        edges = x_vals
        tick_positions = selected_labels
        y_units = f'del log({norm_case})'

    palette = ['#377eb8', '#4daf4a', '#984ea3', '#e41a1c']  # colorblind-friendly

    fig, axes = plt.subplots(1, 4, figsize=(20, 6), sharey=True)

    for idx, (temp, ax) in enumerate(zip(dist_dic.keys(), axes)):
        line_color = palette[idx % len(palette)]

        ax.stairs(dist_dic[temp][0], edges, edgecolor=line_color, linewidth=1, label=f'{temp} Frequency')
        ax.stairs(dist_dic[temp][1], edges, edgecolor='k', linewidth=1, label=f'{temp} Amount')

        if not LWF:
            # NOTE(review): a lower limit of 0 is set before the log scale is applied below;
            # matplotlib may adjust a non-positive log limit. Confirm this is intended.
            ax.set_xlim(0, np.max(selected_labels))
        else:
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(selected_labels, rotation=45, fontsize=12)
            ax.set_xlim(0, lambertw(np.max(selected_labels)).real)

        ax.set_xlabel(norm_case, fontsize=14)

        if idx == 0:
            ax.set_ylabel(f'[Prob. / {y_units}]', fontsize=14)

        if not LWF:
            ax.set_xscale('log')

        ax.legend(loc='best', fontsize=13)

    plt.tight_layout()

    if save_flag:
        plt.savefig(os.path.join(save_dir, save_name))
    plt.show()


def plot_LWF_amount_panel(selected_labels, x_LWF, dist_dict, norm_prec_dic, norm_case, sfig, file_name, save_flag=False,
                          temp_keys=('280K', '290K', '300K', '310K')):
    """
    Plot the amount distributions of several temperature cases on one Lambert W panel.

    For each key in ``temp_keys``, dist_dict[key][1] is divided by np.mean(norm_prec_dic[key])
    and drawn as stairs over the bin edges W(x_LWF), the real part of Lambert W. Ticks sit at
    W(selected_labels) and are labelled with the raw values. The figure is saved when
    ``save_flag`` is True, then shown.

    Parameters:
    selected_labels (list): Tick labels for the x-axis (raw units).
    x_LWF (array): Bin edges in raw units, transformed with Lambert W for plotting.
    dist_dict (dict): Maps a temperature key to a (frequency, amount) pair of arrays.
    norm_prec_dic (dict): Maps a temperature key to values whose mean normalizes the amount.
    norm_case (str): Normalization case, used in the axis labels.
    sfig (str): Directory path to save the plot.
    file_name (str): The filename to save the plot.
    save_flag (bool): Whether to save the plot.
    temp_keys (sequence of str, optional): Temperature keys to plot, in order.
        Defaults to ('280K', '290K', '300K', '310K').
    """
    colors = ['#377eb8', '#4daf4a', '#984ea3', '#e41a1c']  # Colorblind-friendly palette
    selected_transformed = lambertw(selected_labels).real

    fig, ax = plt.subplots(figsize=(10, 6))  # Create a single panel plot
    # Transform the shared bin edges once and reuse them for every curve.
    transformed_bins = lambertw(x_LWF).real

    # Plot amount curves for each temperature case
    for i, temp in enumerate(temp_keys):
        ax.stairs(
            dist_dict[temp][1] / np.mean(norm_prec_dic[temp]),
            transformed_bins,
            edgecolor=colors[i % len(colors)],
            linewidth=1,
            label=f'{temp}'
        )

    ax.set_xticks(selected_transformed)
    ax.set_xticklabels(selected_labels, rotation=45, fontsize=12)
    ax.set_xlabel(norm_case, fontsize=14)
    ax.set_ylabel(f'[Prob. / del LWF( {norm_case} )]', fontsize=14)
    ax.legend(loc='best', fontsize=13)

    plt.tight_layout()
    if save_flag:
        plt.savefig(os.path.join(sfig, file_name))
    plt.show()

def A_v_provider(amount, amount_model,dplnr):
    """
    Build the normal equations for the joint least-squares fit
    amount_model - amount ~ a * amount - b * dplnr.

    Solving A @ [[a], [b]] = v gives the scaling coefficient a and the shift coefficient b.
    With delta = amount_model - amount:

        A = [[ sum(amount**2),       -sum(amount * dplnr)],
             [-sum(amount * dplnr),   sum(dplnr**2)      ]]
        v = [[ sum(amount * delta)],
             [-sum(dplnr * delta)]]

    Sums use np.sum, so a NaN in any input propagates into A and v.

    Parameters:
    -----------
    amount : array-like
        The base amount distribution.
    amount_model : array-like
        The amount distribution being compared against the base.
    dplnr : array-like
        The gradient of the base distribution.

    Returns:
    --------
    A : numpy.ndarray
        The symmetric 2x2 normal-equation matrix.
    v : numpy.ndarray
        The 2x1 right-hand-side column vector.
    """
    delta = amount_model - amount

    # The off-diagonal entries are equal, so compute the cross term once.
    cross_term = -np.sum(amount * dplnr)
    A = np.array([
        [np.sum(np.square(amount)), cross_term],
        [cross_term, np.sum(np.square(dplnr))],
    ])

    v = np.array([np.sum(amount * delta), -np.sum(dplnr * delta)]).reshape(-1, 1)

    return A, v